# =============================================================================
#  This example script implements the main protocol section of the paper.
# =============================================================================

"""
1. Read the Video
"""
from MOSES.Utility_Functions.file_io import read_multiimg_PIL

# Path to the multi-page TIFF of the two-colour time-lapse video.
infile = '../Data/Videos/c_EPC2(G)_CP-A(R)_KSFM+RPMI_5_Fast GFP.tif_R-G.tif'
vidstack = read_multiimg_PIL(infile)

# Sanity-check the stack dimensions. The unpacking fails unless the stack is
# 4-D; the variable names record the expected (frames, rows, cols, channels) order.
n_frame, n_rows, n_cols, n_channels = vidstack.shape
print('Size of video: (%d,%d,%d,%d)' % vidstack.shape)


"""
2. Motion Extraction Parameters
"""
# Dense optical-flow settings handed to the superpixel tracker in step 3.
# The keys match the arguments of OpenCV's calcOpticalFlowFarneback.
optical_flow_params = {
    'pyr_scale': 0.5,
    'levels': 3,
    'winsize': 15,
    'iterations': 3,
    'poly_n': 5,
    'poly_sigma': 1.2,
    'flags': 0,
}

# How many superpixels to seed (and therefore track) in each channel.
n_spixels = 1000

"""
3. Extract the superpixel tracks 
"""
from MOSES.Optical_Flow_Tracking.superpixel_track import compute_grayscale_vid_superpixel_tracks

# Each colour channel is tracked on its own as a grayscale video:
# channel 0 is the 'red' population, channel 1 the 'green' one.
# Every call returns the dense optical flow and the mean superpixel tracks.
optflow_r, meantracks_r = compute_grayscale_vid_superpixel_tracks(
    vidstack[..., 0], optical_flow_params, n_spixels)
optflow_g, meantracks_g = compute_grayscale_vid_superpixel_tracks(
    vidstack[..., 1], optical_flow_params, n_spixels)

"""
4. Save the superpixel tracks
"""
import os

import scipy.io as spio

# Name the output after the bare input file name rather than the full input
# path. Prefixing the full path would give 'meantracks_../Data/...', i.e. a
# non-existent directory. Only the final extension is swapped, because the
# input name itself contains '.tif' twice.
fname = os.path.split(infile)[-1]
savetracksmat = 'meantracks_' + os.path.splitext(fname)[0] + '.mat'
spio.savemat(savetracksmat, {'meantracks_r': meantracks_r,
                             'meantracks_g': meantracks_g})

"""
Visualise the optical flow as a visual sanity check
"""
import os

# pylab is discouraged by matplotlib; pyplot provides every plt.* call used
# in this script.
import matplotlib.pyplot as plt
from MOSES.Visualisation_Tools.motion_field_visualisation import view_ang_flow

# Show the first frame of each channel's motion field: red on the left and
# green on the right, the same layout as every other two-panel figure here.
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(view_ang_flow(optflow_r[0]))
ax[1].imshow(view_ang_flow(optflow_g[0]))
for panel in ax:
    # grid('off') passes a truthy string on current matplotlib; use a bool.
    panel.grid(False)
plt.show()

# Save the dense flow fields. Only the final extension of fname is replaced,
# because the input name itself contains '.tif' twice.
save_optflow_mat = 'optflow_' + os.path.splitext(fname)[0] + '.mat'
spio.savemat(save_optflow_mat, {'optflow_r': optflow_r,
                                'optflow_g': optflow_g})

#################
#   Optional track plotting
#################
from MOSES.Visualisation_Tools.track_plotting import plot_tracks

# Overlay the complete tracks of both channels on the first video frame.
fig, ax = plt.subplots()
ax.imshow(vidstack[0])
plot_tracks(meantracks_r, ax, color='r', lw=1)
plot_tracks(meantracks_g, ax, color='g', lw=1)
# Pin the view to the image extent, with the y-axis running top-down like
# image rows.
ax.set_xlim([0, n_cols])
ax.set_ylim([n_rows, 0])
ax.grid(False)  # a bool; grid('off') does not disable the grid on current matplotlib
ax.axis('off')
plt.show()

# =============================================================================
# Optional track plotting for each frame to create a video using ImageJ
# =============================================================================
import os

from MOSES.Utility_Functions.file_io import mkdir

fname = os.path.split(infile)[-1]
# Folder that receives one PNG per frame; it is created automatically.
save_frame_folder = 'track_video'
mkdir(save_frame_folder)

# Number of preceding time steps of each track drawn behind the current point.
len_segment = 5
# Only the final extension is dropped; the input name itself contains '.tif'.
frame_stem = os.path.splitext(fname)[0]

for frame in range(len_segment, n_frame):
    frame_img = vidstack[frame]

    fig = plt.figure()
    # A 1-inch-tall canvas with the video's aspect ratio, saved at
    # dpi=n_rows, gives an image of exactly n_cols x n_rows pixels.
    fig.set_size_inches(float(n_cols) / n_rows, 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(frame_img, alpha=0.6)
    # visualise only a short trailing segment of each track
    plot_tracks(meantracks_r[:, frame - len_segment:frame + 1], ax, color='r', lw=1)
    plot_tracks(meantracks_g[:, frame - len_segment:frame + 1], ax, color='g', lw=1)
    ax.set_xlim([0, n_cols])
    ax.set_ylim([n_rows, 0])
    ax.grid(False)
    ax.axis('off')
    # tip: use multiples of n_rows to increase image resolution
    out_name = 'tracksimg-%s_%s.png' % (str(frame + 1).zfill(3), frame_stem)
    fig.savefig(os.path.join(save_frame_folder, out_name), dpi=n_rows)
    # Close rather than plt.show(): in a plain script show() blocks once per
    # frame, and figures left open accumulate in memory over a long video.
    plt.close(fig)

"""
5. Construct the MOSES mesh and the alternative radial and k-nearest-neighbour meshes.
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import (
    construct_MOSES_mesh,
    construct_radial_neighbour_mesh,
    construct_knn_neighbour_mesh,
)

# Superpixel spacing, read off the initial track positions.
# NOTE(review): this subtracts the two coordinates of superpixel 1 at frame 0.
# That equals the grid spacing when superpixels start on a regular square grid
# numbered row by row; confirm this holds for the tracker's output.
spixel_size = meantracks_r[1, 0, 1] - meantracks_r[1, 0, 0]

# Neighbourhood settings shared by every mesh below; change them here once.
mesh_dist_thresh = 1.2  # passed as dist_thresh together with spixel_size
knn_k = 4               # neighbours per superpixel in the k-NN mesh

# MOSES mesh for each colour channel.
MOSES_mesh_strain_time_r, MOSES_mesh_neighborlist_r = construct_MOSES_mesh(meantracks_r, dist_thresh=mesh_dist_thresh, spixel_size=spixel_size)
MOSES_mesh_strain_time_g, MOSES_mesh_neighborlist_g = construct_MOSES_mesh(meantracks_g, dist_thresh=mesh_dist_thresh, spixel_size=spixel_size)

# Alternative: the radial neighbours mesh.
radial_mesh_strain_time_r, radial_neighbourlist_time_r = construct_radial_neighbour_mesh(meantracks_r, dist_thresh=mesh_dist_thresh, spixel_size=spixel_size, use_counts=False)
radial_mesh_strain_time_g, radial_neighbourlist_time_g = construct_radial_neighbour_mesh(meantracks_g, dist_thresh=mesh_dist_thresh, spixel_size=spixel_size, use_counts=False)

# Alternative: the k nearest neighbour mesh.
knn_mesh_strain_time_r, knn_neighbourlist_time_r = construct_knn_neighbour_mesh(meantracks_r, k=knn_k)
knn_mesh_strain_time_g, knn_neighbourlist_time_g = construct_knn_neighbour_mesh(meantracks_g, k=knn_k)

"""
5b. Visualise the MOSES mesh.
"""
# Visualise each channel's MOSES mesh at frame 20 with networkx. The
# neighbour list is first converted into a networkx graph object.
from MOSES.Visualisation_Tools.mesh_visualisation import visualise_mesh
from MOSES.Motion_Analysis.mesh_statistics_tools import from_neighbor_list_to_graph

mesh_frame20_networkx_G_red = from_neighbor_list_to_graph(meantracks_r, MOSES_mesh_neighborlist_r, 20)
mesh_frame20_networkx_G_green = from_neighbor_list_to_graph(meantracks_g, MOSES_mesh_neighborlist_g, 20)

# 1 x 2 canvas: red mesh in the left panel, green mesh in the right panel.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 15))
mesh_panels = [(ax[0], mesh_frame20_networkx_G_red, meantracks_r, 'r'),
               (ax[1], mesh_frame20_networkx_G_green, meantracks_g, 'g')]
for panel, mesh_graph, tracks, colour in mesh_panels:
    panel.imshow(vidstack[20], alpha=0.7)
    # Node positions are given as (x, y): track column 1 then column 0.
    visualise_mesh(mesh_graph, tracks[:, 20, [1, 0]], panel, node_size=20, linewidths=1, width=1, node_color=colour)
    panel.set_ylim([n_rows, 0])
    panel.set_xlim([0, n_cols])
    panel.grid(False)  # a bool; grid('off') does not disable the grid on current matplotlib
    panel.axis('off')
plt.show()

# =============================================================================
# =============================================================================
# =============================================================================
# # #  Extracting Common MOSES measurements. 
# =============================================================================
# =============================================================================
# =============================================================================
"""
Construct the motion saliency map for each colour.
"""
import numpy as np
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_motion_saliency_map

# Motion saliency map of each channel; both channels use identical settings.
final_saliency_map_r, spatial_time_saliency_map_r = compute_motion_saliency_map(
    meantracks_r, dist_thresh=5. * spixel_size, shape=(n_rows, n_cols),
    filt=1, filt_size=spixel_size)
final_saliency_map_g, spatial_time_saliency_map_g = compute_motion_saliency_map(
    meantracks_g, dist_thresh=5. * spixel_size, shape=(n_rows, n_cols),
    filt=1, filt_size=spixel_size)

# Optional: view the per-channel maps side by side (red left, green right)
# on a shared colour scale so they can be compared directly.
max_val = np.max((final_saliency_map_r.max(), final_saliency_map_g.max()))

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 15))
for panel, saliency_map in zip(ax, (final_saliency_map_r, final_saliency_map_g)):
    panel.imshow(saliency_map, cmap='coolwarm', vmin=0, vmax=max_val)
# To save the figure:
# fig.savefig('protocol_scripts/individual_motion_saliency_map_example.svg', dpi=300, bbox_inches='tight')
plt.show()


from MOSES.Motion_Analysis.mesh_statistics_tools import compute_boundary_formation_index

# Boundary formation index from the two channels' saliency maps. The call
# also returns av_saliency_map, which is shown below with the index in the title.
boundary_formation_index, av_saliency_map = compute_boundary_formation_index(final_saliency_map_r, final_saliency_map_g, spixel_size, pad_multiple=3)

fig = plt.figure()
plt.title('Boundary Formation Index %.3f' % (boundary_formation_index))
plt.imshow(av_saliency_map, cmap='coolwarm')
plt.axis('off')
plt.grid(False)  # a bool; grid('off') does not disable the grid on current matplotlib
plt.show()


"""
Plot the MOSES mesh strain curve.
"""
import numpy as np
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_MOSES_mesh_strain_curve

# One strain curve per channel (red, then green). Pass normalise=True to get
# curves scaled to 0-1 instead.
mesh_strain_r, mesh_strain_g = (
    compute_MOSES_mesh_strain_curve(strain_time, normalise=False)
    for strain_time in (MOSES_mesh_strain_time_r, MOSES_mesh_strain_time_g)
)

# A single curve for the whole video: the mean of the two channel curves.
mesh_strain_curve_video = .5 * (mesh_strain_r + mesh_strain_g)
# Dividing by the maximum mirrors what normalise=True does for one channel.
normalised_mesh_strain_curve_video = mesh_strain_curve_video / np.max(mesh_strain_curve_video)

# Plot the channel curves and their mean.
last_index = len(mesh_strain_r) - 1
fig, ax = plt.subplots(figsize=(5, 5))
curve_specs = ((mesh_strain_r, 'r', 'red cells'),
               (mesh_strain_g, 'g', 'green cells'),
               (mesh_strain_curve_video, 'k', 'mean curve'))
for curve, fmt, curve_label in curve_specs:
    ax.plot(curve, fmt, lw=3, label=curve_label)
# Shade the final 5 frames, the same window as last_frames=5 in the mesh
# stability index computation.
ax.fill_between([last_index - 5, last_index],
                y1=[0, 0], y2=[25, 25], alpha=0.5)
plt.xlim([0, last_index])
plt.ylim([0, 25])
plt.legend(loc='best')
plt.xlabel('Time [Frame No.]')
plt.ylabel('MOSES Mesh Strain [Pixels]')
plt.show()


"""
Compute the mesh stability index
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_MOSES_mesh_stability_index

# The index averages over the final `last_frames` frames. Choose last_frames
# so that this window roughly covers the plateau of the mesh strain curve.
mesh_stability_index, normalised_mesh_strain_curve = compute_MOSES_mesh_stability_index(
    MOSES_mesh_strain_time_r,
    MOSES_mesh_strain_time_g,
    last_frames=5,
)
print('mesh stability index: %.3f' % mesh_stability_index)

"""
Compute the mesh strain vector and visualise this. 
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import construct_mesh_strain_vector

# Mesh strain vector of each channel, computed from that channel's mesh
# (given as a list of neighbours, a.k.a. a region adjacency list). Each
# channel's tracks must be paired with its OWN neighbour list: feeding the
# red neighbour list to the green tracks would link the wrong superpixels.
mesh_strain_vector_r = construct_mesh_strain_vector(meantracks_r, [MOSES_mesh_neighborlist_r])
mesh_strain_vector_g = construct_mesh_strain_vector(meantracks_g, [MOSES_mesh_neighborlist_g])

# Show each channel's mesh with its strain vectors at frame 20 on a 1 x 2
# canvas (red left, green right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 15))
strain_panels = [(ax[0], mesh_frame20_networkx_G_red, meantracks_r, mesh_strain_vector_r, 'r'),
                 (ax[1], mesh_frame20_networkx_G_green, meantracks_g, mesh_strain_vector_g, 'g')]
for panel, mesh_graph, tracks, strain_vector, colour in strain_panels:
    panel.imshow(vidstack[20], alpha=0.5)
    visualise_mesh(mesh_graph, tracks[:, 20, [1, 0]], panel, node_size=0, linewidths=10, width=1, node_color=colour)
    # Arrows start at (x, y) = (track column 1, track column 0). The second
    # vector component is negated, presumably to compensate for the flipped
    # y-axis (ylim [n_rows, 0]) set just below.
    panel.quiver(tracks[:, 20, 1],
                 tracks[:, 20, 0],
                 strain_vector[20, :, 1],
                 -strain_vector[20, :, 0], color=colour, scale_units='xy', width=0.005)
    panel.set_ylim([n_rows, 0])
    panel.set_xlim([0, n_cols])
    panel.grid(False)  # a bool; grid('off') does not disable the grid on current matplotlib
    panel.axis('off')
plt.show()


"""
Compute the MOSES mesh order using the computed mesh strain vector.
"""
import numpy as np
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_mesh_order

# Mesh order curve of each channel, computed from its mesh strain vector.
mesh_order_curve_r = compute_mesh_order(mesh_strain_vector_r, remove_mean=False)
mesh_order_curve_g = compute_mesh_order(mesh_strain_vector_g, remove_mean=False)

# One curve for the video (mean of the channels) and its NaN-ignoring time
# average, computed once and reused for both the printout and the plot.
av_mesh_order = .5 * (mesh_order_curve_r + mesh_order_curve_g)
mean_mesh_order = np.nanmean(av_mesh_order)
print('average mesh order: %.5f' % (mean_mesh_order))

# Plot the mesh order curves, with the time average as a dashed line.
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(mesh_order_curve_r, 'r', lw=3, label='red cells')
ax.plot(mesh_order_curve_g, 'g', lw=3, label='green cells')
ax.plot(av_mesh_order, 'k', lw=3, label='average')
ax.plot([0, len(av_mesh_order) - 1],
        [mean_mesh_order, mean_mesh_order], 'k--', lw=3)
plt.xlim([0, len(mesh_order_curve_r) - 1])
plt.ylim([0, 0.015])
plt.legend(loc='best')
# The curves are indexed by frame number, so label the axis in frames (as in
# the mesh strain plot), not hours.
plt.xlabel('Time [Frame No.]')
plt.ylabel('MOSES Mesh Order')
# To save the figure:
# fig.savefig('protocol_scripts/MOSES_mesh_order_curve_example.svg', dpi=300, bbox_inches='tight')
plt.show()

